using LinearAlgebra

"""
    IdentityMatrix()

Singleton marker for the identity operator.

Most arithmetic converts it to its `SimpleSparse` form via `sparse`, a single
`(0, 0) => 1` coefficient. Multiplying it with an `AbstractMatrix` on either
side instead returns a `deepcopy` of that matrix.
"""
struct IdentityMatrix end

sparse(::IdentityMatrix) = SimpleSparse(Dict((0, 0) => 1))

# Dense `T × T` identity (a `Diagonal` from LinearAlgebra's `I`).
matrix(::IdentityMatrix, T) = I(T)

function Base.show(io::IO, ::IdentityMatrix)
    return print(io, "IdentityMatrix")
end

Base.:+(id::IdentityMatrix) = id

function Base.:-(id::IdentityMatrix)
    return -sparse(id)
end

# Identity times a matrix is that matrix; a deep copy is returned so the
# result never aliases the argument.
Base.:*(::IdentityMatrix, A::AbstractMatrix) = deepcopy(A)
Base.:*(A::AbstractMatrix, ::IdentityMatrix) = deepcopy(A)


"""
    SimpleSparse(elements)

Sparse operator whose coefficients live in `elements`, a dictionary from
`(i, m)` keys to coefficients.

`indices` and `xs` start out as `nothing`. They are a lazily filled dense copy
of the keys and coefficients, populated by `array!`.
"""
mutable struct SimpleSparse
    elements::Any
    indices::Any
    xs::Any

    function SimpleSparse(elements)
        return new(elements, nothing, nothing)
    end
end

# Binary `+`, `-` and `*` with an `IdentityMatrix` operand: each identity operand
# is replaced by its `SimpleSparse` form (`sparse`) and the same operator is
# dispatched again. The `IdentityMatrix * AbstractMatrix` methods are more
# specific and therefore still take precedence for `*`.
for op ∈ (:+, :-, :*)
    @eval begin
        Base.$op(a::IdentityMatrix, b::IdentityMatrix) = $op(sparse(a), sparse(b))
        Base.$op(a::IdentityMatrix, b::SimpleSparse) = $op(sparse(a), b)
        Base.$op(a::SimpleSparse, b::IdentityMatrix) = $op(a, sparse(b))
        Base.$op(a::IdentityMatrix, b) = $op(sparse(a), b)
        Base.$op(a, b::IdentityMatrix) = $op(a, sparse(b))
    end
end

# NOTE(review): `::Type{AbstractArray}` matches only the abstract type itself,
# not concrete arrays such as `Matrix{Float64}`. No matching `convert` is
# visible here, so this rule presumably has no effect on ordinary promotion.
# Confirm before relying on it.
function Base.promote_rule(::Type{SimpleSparse}, ::Type{AbstractArray})
    return SimpleSparse
end

"""
    from_simple_diagonals(elements)

Build a `SimpleSparse` from an iterable of `(i, x)` pairs.

Each pair becomes the key `(i, 0)` with coefficient `x`. If an `i` repeats,
the later pair wins.
"""
function from_simple_diagonals(elements)
    coefficients = Dict((offset, 0) => value for (offset, value) ∈ elements)
    return SimpleSparse(coefficients)
end

"""
    matrix(s::SimpleSparse, T)

Dense `T × T` representation of `s`, obtained by adding `s` to a zero matrix.
"""
function matrix(s::SimpleSparse, T)
    dense_zero = zeros(T, T)
    return s + dense_zero
end

"""
    array!(s::SimpleSparse)

Return `(indices, xs)`:

- `indices` is an `n × 2` matrix whose rows are the `(i, m)` keys of `s.elements`.
- `xs` is the vector of the matching coefficients, in the same row order.

The pair is computed on the first call and cached in `s.indices` / `s.xs`;
later calls return the cached arrays. An empty `elements` yields a `0 × 2`
matrix and an empty vector.

NOTE(review): the cache is not invalidated if `s.elements` is mutated after
the first call.
"""
function array!(s::SimpleSparse)
    if isnothing(s.indices)
        # Materialise the pairs once so keys and values share one ordering.
        # Splatting every pair into `zip(...)` would compile a new method
        # specialization for each distinct element count.
        kv = collect(s.elements)
        s.indices = [first(p)[j] for p ∈ kv, j ∈ 1:2]
        s.xs = [last(p) for p ∈ kv]
    end
    return s.indices, s.xs
end

"""
    transpose(s::SimpleSparse)

Transpose of `s`: every key `(i, m)` becomes `(-i, m)` with the same
coefficient. The truncation `m` bounds rows and columns alike, so it is kept.
"""
function Base.transpose(s::SimpleSparse)
    flipped = Dict((-shift, cut) => coef for ((shift, cut), coef) ∈ s.elements)
    return SimpleSparse(flipped)
end

# True when no coefficients are stored (stored zeros still count as entries).
Base.isempty(s::SimpleSparse) = isempty(s.elements)

"""
    nonzero(s::SimpleSparse)

Return a new `SimpleSparse` without the coefficients whose magnitude is below
`1e-14`. `NaN` coefficients are kept because `abs(NaN) < 1e-14` is false.
"""
function nonzero(s::SimpleSparse)
    kept = filter(p -> !(abs(last(p)) < 1e-14), s.elements)
    return SimpleSparse(kept)
end

"""
    iszero(s::SimpleSparse)

True when every stored coefficient has magnitude below `1e-14`, i.e. when
`nonzero(s)` would be empty.
"""
# Extends `Base.iszero`, like the sibling `Base.isempty`/`Base.transpose`
# methods. An unqualified definition would create a module-local `iszero` that
# shadows Base and breaks `iszero(::Number)` calls in the same module.
Base.iszero(s::SimpleSparse) = all(x -> abs(x) < 1e-14, values(s.elements))

Base.:+(s::SimpleSparse) = s

# Negate every coefficient; keys are unchanged.
function Base.:-(s::SimpleSparse)
    negated = Dict(key => -coef for (key, coef) ∈ s.elements)
    return SimpleSparse(negated)
end

Base.:*(s::SimpleSparse, t::SimpleSparse) = SimpleSparse(multiply_rs_rs(s, t))

# A vector is treated as a one-column matrix and the column is returned.
function Base.:*(s::SimpleSparse, t::AbstractVector)
    indices, xs = array!(s)
    column = reshape(t, (size(t)..., 1))
    return multiply_rs_matrix(indices, xs, column)[:, 1]
end

function Base.:*(s::SimpleSparse, t::AbstractMatrix)
    indices, xs = array!(s)
    return multiply_rs_matrix(indices, xs, t)
end

# Right multiplication reuses left multiplication via A * S = (Sᵀ * Aᵀ)ᵀ.
Base.:*(s::AbstractArray, t::SimpleSparse) = transpose(transpose(t) * transpose(s))

"""
    +(s::SimpleSparse, t::SimpleSparse)

Sum of two sparse operators. Coefficients stored under the same key are
added, and that key is dropped when the sum has magnitude below `1e-14`.
Keys present in only one operand are copied over unchanged.

The result's dictionary has key and value types promoted from both operands.
For example, the `Int`-valued identity plus `Float64` coefficients no longer
throws an `InexactError` while storing a non-integer sum.
"""
function Base.:+(s::SimpleSparse, t::SimpleSparse)
    K = promote_type(keytype(s.elements), keytype(t.elements))
    V = promote_type(valtype(s.elements), valtype(t.elements))
    elements = Dict{K,V}(s.elements)
    for (key, x) ∈ t.elements
        if haskey(elements, key)
            elements[key] += x
            # Remove coefficients that cancel to (numerically) zero.
            if abs(elements[key]) < 1e-14
                delete!(elements, key)
            end
        else
            elements[key] = x
        end
    end
    return SimpleSparse(elements)
end

"""
    +(s::SimpleSparse, t::AbstractMatrix)
    +(t::AbstractMatrix, s::SimpleSparse)

Dense sum of `s` and `t`. For each coefficient `x` stored under key `(i, m)`,
`x` is added to every entry `(r, r + i)` of `t` whose row and column both
exceed `m` and lie inside `t`.

Returns a new dense `Matrix`; `t` is not modified. The element type is
promoted to hold both `eltype(t)` and the coefficients, so complex or
non-integer coefficients work, and complex coefficients are added without
conjugation.
"""
function Base.:+(s::SimpleSparse, t::AbstractMatrix)
    nrows, ncols = size(t)
    # Promote on runtime coefficient types: `s.elements` may be an untyped Dict
    # (e.g. the result of `multiply_rs_rs`).
    R = mapreduce(typeof, promote_type, values(s.elements); init = eltype(t))
    out = Matrix{R}(t)
    for ((i, m), x) ∈ s.elements
        # Rows r with r > m and r + i > m, with both r and r + i in bounds.
        first_row = max(m, m - i, 0, -i) + 1
        last_row = min(nrows, ncols - i)
        for r ∈ first_row:last_row
            out[r, r + i] += x
        end
    end
    return out
end
Base.:+(s::AbstractMatrix, t::SimpleSparse) = t + s

# Subtraction, in every supported operand combination, is addition of the negation.
for (L, R) ∈ ((:SimpleSparse, :SimpleSparse), (:SimpleSparse, :Any), (:Any, :SimpleSparse))
    @eval Base.:-(s::$L, t::$R) = s + (-t)
end

"""
    *(s::SimpleSparse, t::Number)
    *(s::Number, t::SimpleSparse)

Scale every coefficient of the operator by the number, computed as
`number * coefficient`.
"""
function Base.:*(s::SimpleSparse, t::Number)
    scaled = Dict(key => t * coef for (key, coef) ∈ s.elements)
    return SimpleSparse(scaled)
end
Base.:*(s::Number, t::SimpleSparse) = t * s

"""
    ==(s::SimpleSparse, t::SimpleSparse)

Two operators are equal when their coefficient dictionaries compare equal.
"""
Base.:(==)(s::SimpleSparse, t::SimpleSparse) = s.elements == t.elements

# `isequal` falls back to `==` for this type, so `hash` must be derived from the
# same data. Otherwise equal operators land in different `Dict`/`Set` slots.
# The one exception is coefficients of `0.0` vs `-0.0`: `==` treats them as
# equal, but their hashes differ.
Base.hash(s::SimpleSparse, h::UInt) = hash(s.elements, hash(:SimpleSparse, h))

Base.show(io::IO, s::SimpleSparse) = print(io, "SimpleSparse({$(join(["($i, $m) => $x " for ((i, m), x) ∈ s.elements], ", "))})")

"""
    multiply_basis(t1, t2)

Key of the product of two basis operators.

A key `(i, m)` denotes the operator with unit entries at `(r, r + i)` for every
row `r` with both `r > m` and `r + i > m` (the pattern applied by
`multiply_rs_matrix`). For `t1 = (i, m)` and `t2 = (j, n)` the product has
shift `k = i + j` and truncation `l`.

Only lower bounds on the row index are tracked; upper bounds implied by a
finite matrix size are not part of the key.
"""
function multiply_basis(t1, t2)
    (i, m), (j, n) = t1, t2
    k = i + j
    # A product entry at row r survives iff r > m, r + i > m, r + i > n and
    # r + k > n. The key (k, l) requires r > l and r + k > l, i.e.
    # r > l - min(0, k). Equating the two lower bounds on r gives l.
    l = max(m, m - i, n - i, n - k) + min(0, k)
    return k, l
end

"""
    multiply_rs_rs(s1, s2)

Coefficient dictionary of the product `s1 * s2`.

Every pair of entries is multiplied, the keys are combined with
`multiply_basis`, and coefficients landing on the same key are summed.
Sums that come out as zero are kept. The returned `Dict` is untyped.
"""
function multiply_rs_rs(s1, s2)
    product = Dict()
    for (key1, a) ∈ s1.elements, (key2, b) ∈ s2.elements
        key = multiply_basis(key1, key2)
        term = a * b
        product[key] = haskey(product, key) ? product[key] + term : term
    end
    return product
end

"""
    multiply_rs_matrix(indices, xs, A)

Dense product of a sparse operator with `A`.

Row `c` of `indices` holds a key `(i, m)` and `xs[c]` its coefficient `x`.
Each key adds `x * A[t + i, :]` to row `t` of the result, for every `t` with
`t > m` and `t + i > m`, where both `t` and `t + i` lie in `1:size(A, 1)`.

Returns a new matrix of size `(size(A, 1), size(A, 2))`. Its element type is
`promote_type(Float64, eltype(A), typeof.(xs)...)`. That is `Float64` for the
usual real inputs, as before, and e.g. `ComplexF64` for complex data, which
previously threw an `InexactError`.
"""
function multiply_rs_matrix(indices, xs, A)
    nrows, ncols = size(A, 1), size(A, 2)
    R = mapreduce(typeof, promote_type, xs; init = promote_type(Float64, eltype(A)))
    Aout = zeros(R, nrows, ncols)
    for c ∈ axes(indices, 1)
        i, m, x = indices[c, 1], indices[c, 2], xs[c]
        # Same row ranges as the per-sign cases: m+1:T (i == 0),
        # m+1:T-i (i > 0) and m-i+1:T (i < 0), clamped to valid indices.
        rows = (max(m, m - i, 0, -i) + 1):min(nrows, nrows - i)
        # Column-outer / row-inner order matches Julia's column-major layout.
        for col ∈ 1:ncols, t ∈ rows
            Aout[t, col] += x * A[t + i, col]
        end
    end
    return Aout
end

"""
    make_matrix(A, T)

Return `A` itself (same object, `T` ignored) when it is already an
`AbstractArray`; otherwise delegate to `matrix(A, T)`.
"""
make_matrix(A::AbstractArray, T) = A
make_matrix(A, T) = matrix(A, T)
